import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import utils


class EmbeddingLoss(nn.Module):
    """Margin loss on snippet embeddings of foreground and background seeds.

    Per sample, three kinds of hinge terms are computed on cosine similarity:

    * compactness: for every foreground (resp. background) embedding, its
      least similar partner inside the same group must reach at least
      ``th_similar_min``;
    * fg -> bg separation: for every foreground embedding, its most similar
      background embedding must stay below ``th_different_max``;
    * bg -> fg separation: the same, seen from each background embedding.
    """

    def __init__(self, th_similar_min, th_different_max):
        """Store the two similarity margins used by the hinge terms."""
        super(EmbeddingLoss, self).__init__()
        self.th_similar_min = th_similar_min
        self.th_different_max = th_different_max

    def cosine_similarity(self, x1, x2, eps=1e-8):
        '''
        pair-wise cosine similarity
        x1: [M, D]
        x2: [N, D]
        similarity: [M, N]

        The product of the norms is clamped to ``eps`` so all-zero rows do
        not cause a division by zero.
        '''
        w1 = x1.norm(p=2, dim=1, keepdim=True)
        w2 = x2.norm(p=2, dim=1, keepdim=True)
        similarity = torch.mm(x1, x2.t()) / (w1 * w2.t()).clamp(min=eps)
        return similarity

    def _compactness_loss(self, group):
        """Hinge on the hardest (least similar) in-group partner of each row of ``group`` [K, D]."""
        sim = self.cosine_similarity(group, group)
        sim_hard = torch.min(sim, dim=1)[0]
        zero = torch.zeros_like(sim_hard)
        return torch.max(self.th_similar_min - sim_hard, zero).mean()

    def _separation_loss(self, sim_hard):
        """Hinge on the hardest (most similar) cross-group similarities."""
        zero = torch.zeros_like(sim_hard)
        return torch.max(sim_hard - self.th_different_max, zero).mean()

    def forward(self, embeddings, act_agnostic, bkg_seed):
        """Compute the embedding loss for a batch.

        Args:
            embeddings: tensor indexed as [batch, D, T] (the time axis is the
                one selected by the seed masks).
            act_agnostic: [batch, T] tensor; entries equal to 1 mark
                foreground (action) snippets.
            bkg_seed: [batch, T] tensor; entries equal to 1 mark background
                seed snippets.

        Returns:
            Scalar tensor: mean foreground compactness over the samples that
            have foreground seeds, plus the mean of the background
            compactness and separation terms over the samples that have
            background seeds. A term with no contributing sample counts as
            zero, so the result is always finite and never raises on empty
            seed sets.
        """
        num_batch = embeddings.size(0)
        loss_act_batch = embeddings.new_zeros(())
        loss_batch = embeddings.new_zeros(())
        # Number of samples that actually contributed to each term.
        num_batch_dynamic = 0
        num_batch_dynamic_bkg = 0

        for ibat in range(num_batch):
            embedding = embeddings[ibat, :, :]
            mask_act = torch.where(act_agnostic[ibat, :] == 1)[0]
            mask_bkg = torch.where(bkg_seed[ibat] == 1)[0]
            has_fg = mask_act.numel() > 0
            has_bkg = mask_bkg.numel() > 0

            fg_embedding = None
            if has_fg:
                # Reducing an empty similarity matrix would raise, so the
                # foreground term is only computed when seeds exist.
                fg_embedding = embedding[:, mask_act].t()  # [M, D]
                loss_act_batch = loss_act_batch + self._compactness_loss(fg_embedding)
                num_batch_dynamic += 1

            if not has_bkg:
                continue

            bg_embedding = embedding[:, mask_bkg].t()  # [N, D]

            # all bg embeddings should be similar
            loss_bg2bg = self._compactness_loss(bg_embedding)

            if has_fg:
                # fg embeddings should be different with bg embeddings
                sim_fg2bg = self.cosine_similarity(fg_embedding, bg_embedding)
                loss_fg2bg = self._separation_loss(torch.max(sim_fg2bg, dim=1)[0])
                loss_bg2fg = self._separation_loss(torch.max(sim_fg2bg, dim=0)[0])
                loss_batch = loss_batch + loss_bg2bg + loss_fg2bg + loss_bg2fg
            else:
                # Nothing to separate from: only the compactness term applies.
                loss_batch = loss_batch + loss_bg2bg
            num_batch_dynamic_bkg += 1

        loss = embeddings.new_zeros(())
        if num_batch_dynamic > 0:
            loss = loss + loss_act_batch / num_batch_dynamic
        if num_batch_dynamic_bkg > 0:
            loss = loss + loss_batch / num_batch_dynamic_bkg
        return loss


class Total_loss(nn.Module):
    """Weighted sum of the video-, frame-, score-, feature- and embedding-level losses.

    ``lambdas`` is indexed 0..6 and weights, in order: video classification,
    action frame loss, background frame loss, score (completeness) loss,
    feature contrastive loss, pseudo-instance BCE and embedding loss.
    """

    def __init__(self, lambdas, th_similar_min, th_different_max):
        """Store the loss weights and build the sub-criteria."""
        super(Total_loss, self).__init__()
        self.tau = 0.1  # temperature of the feature contrastive term
        self.sampling_size = 3  # forwarded to utils.feature_sampling
        self.lambdas = lambdas
        self.ce_criterion = nn.BCELoss(reduction='none')
        self.embedding_loss = EmbeddingLoss(th_similar_min=th_similar_min, th_different_max=th_different_max)

    def _cls_loss(self, scores, labels):
        '''
        calculate classification loss
        1. normalise each label row so that it sums to 1
        2. cross-entropy between the normalised labels and softmax(scores),
           averaged over the batch
        '''
        labels = labels / (torch.sum(labels, dim=1, keepdim=True) + 1e-10)
        clsloss = -torch.mean(torch.sum(labels * F.log_softmax(scores, dim=1), dim=1), dim=0)
        return clsloss

    @staticmethod
    def _segment_bounds(diff_seq):
        """Return the transition indices of a 1-D temporal difference sequence.

        The non-zero positions of ``diff_seq`` are collected. If the first
        transition is not a rising edge (+1), -1 is prepended; if the last is
        not a falling edge (-1), the last valid index is appended. Consecutive
        pairs ``(idx[2i], idx[2i + 1])`` then delimit one segment each.
        Returns an empty list when there is no transition.
        """
        idx = torch.nonzero(diff_seq).squeeze(1).tolist()
        if not idx:
            return idx
        if diff_seq[idx[0]] != 1:
            idx = [-1] + idx
        if diff_seq[idx[-1]] != -1:
            idx = idx + [diff_seq.shape[0] - 1]
        return idx

    def forward(self, vid_score, cas_sigmoid_fuse, output, features, stored_info, label, point_anno, step):
        """Compute the total training loss.

        Args:
            vid_score: unused by the current implementation.
            cas_sigmoid_fuse: class activation sequence, indexed as
                [batch, T, C + 1]; values are used as BCE probabilities.
            output: 6-tuple ``(score_base, cas_base, score_supp, cas_supp,
                embedding, fore_weights)``; only ``score_base``,
                ``score_supp`` and ``embedding`` are used here.
            features: per-video features, ``features[b]`` is passed to
                ``utils.feature_sampling``.
            stored_info: dict; ``stored_info['new_dense_anno']`` holds the
                dense pseudo labels. The score/feature terms are skipped when
                it has fewer than two dimensions.
            label: [batch, C] video-level labels.
            point_anno: [batch, T, C] point-level annotations.
            step: unused by the current implementation.

        Returns:
            ``(loss_total, loss)`` where ``loss`` is a dict with the
            individual terms.
        """
        loss = {}
        device = cas_sigmoid_fuse.device
        eps = 1e-8

        # Video-level classification: the extra (last) class is on for the
        # base branch and off for the suppressed branch.
        score_base, cas_base, score_supp, cas_supp, embedding, fore_weights = output
        label_base = torch.cat((label, torch.ones((label.shape[0], 1), device=label.device)), dim=1)
        label_supp = torch.cat((label, torch.zeros((label.shape[0], 1), device=label.device)), dim=1)

        loss_base = self._cls_loss(score_base, label_base)
        loss_supp = self._cls_loss(score_supp, label_supp)
        loss_vid = loss_base * 0.5 + loss_supp * 0.5

        # Append an all-zero column for the extra class.
        point_anno = torch.cat(
            (point_anno,
             torch.zeros((point_anno.shape[0], point_anno.shape[1], 1), device=point_anno.device)),
            dim=2)

        weighting_seq_act = point_anno.max(dim=2, keepdim=True)[0]
        num_actions = point_anno.max(dim=2)[0].sum(dim=1)

        focal_weight_act = (1 - cas_sigmoid_fuse) * point_anno + cas_sigmoid_fuse * (1 - point_anno)
        focal_weight_act = focal_weight_act ** 2

        # The BCE map is shared by the focal frame loss and the plain mean.
        bce_act = self.ce_criterion(cas_sigmoid_fuse, point_anno)

        # clamp: a video without any annotated frame has a zero numerator as
        # well, so it contributes 0 instead of NaN.
        loss_frame = ((focal_weight_act * bce_act * weighting_seq_act).sum(dim=2).sum(dim=1)
                      / num_actions.clamp(min=eps)).mean()

        pseudo_instance_loss = bce_act.mean()

        act_seed, bkg_seed = utils.select_seed(cas_sigmoid_fuse.detach().cpu(), point_anno.detach().cpu())

        act_agnostic = act_seed.max(dim=2)[0]
        loss_embedding = self.embedding_loss(embedding, act_agnostic, bkg_seed)

        bkg_seed = bkg_seed.unsqueeze(-1).to(device)

        # Background target: only the extra (last) class is active.
        point_anno_bkg = torch.zeros_like(point_anno)
        point_anno_bkg[:, :, -1] = 1

        weighting_seq_bkg = bkg_seed
        # Squeeze to [batch]: keeping the trailing singleton dimension would
        # broadcast the division below to [batch, batch] and normalise each
        # video by every other video's seed count.
        num_bkg = bkg_seed.sum(dim=1).squeeze(-1).to(cas_sigmoid_fuse.dtype)
        focal_weight_bkg = (1 - cas_sigmoid_fuse) * point_anno_bkg + cas_sigmoid_fuse * (1 - point_anno_bkg)
        focal_weight_bkg = focal_weight_bkg ** 2

        bce_bkg = self.ce_criterion(cas_sigmoid_fuse, point_anno_bkg)
        loss_frame_bkg = ((focal_weight_bkg * bce_bkg * weighting_seq_bkg).sum(dim=2).sum(dim=1)
                          / num_bkg.clamp(min=eps)).mean()

        loss_score_act = 0
        loss_score_bkg = 0
        loss_feat = 0

        if len(stored_info['new_dense_anno'].shape) > 1:
            new_dense_anno = stored_info['new_dense_anno'].to(device)
            new_dense_anno = torch.cat(
                (new_dense_anno,
                 torch.zeros((new_dense_anno.shape[0], new_dense_anno.shape[1], 1), device=device)),
                dim=2)

            # +1 where an action segment starts, -1 where it ends.
            act_idx_diff = new_dense_anno[:, 1:] - new_dense_anno[:, :-1]
            for b in range(new_dense_anno.shape[0]):
                gt_classes = torch.nonzero(label[b]).squeeze(1)
                act_count = 0
                loss_score_act_batch = 0
                loss_feat_batch = 0

                for c in gt_classes:
                    range_idx = self._segment_bounds(act_idx_diff[b, :, c])
                    if len(range_idx) == 0:
                        continue

                    # label 1 = action segment, label 0 = background segment
                    label_lst = []
                    feature_lst = []

                    if range_idx[0] > -1:
                        # Leading background before the first action segment.
                        start_bkg = 0
                        end_bkg = range_idx[0]

                        label_lst.append(0)
                        feature_lst.append(
                            utils.feature_sampling(features[b], start_bkg, end_bkg + 1, self.sampling_size))

                    for i in range(len(range_idx) // 2):
                        if range_idx[2 * i + 1] - range_idx[2 * i] < 1:
                            continue

                        label_lst.append(1)
                        feature_lst.append(
                            utils.feature_sampling(features[b], range_idx[2 * i] + 1, range_idx[2 * i + 1] + 1,
                                                   self.sampling_size))

                        if range_idx[2 * i + 1] != act_idx_diff.shape[1] - 1:
                            # Background that follows this action segment.
                            start_bkg = range_idx[2 * i + 1] + 1

                            if i == (len(range_idx) // 2 - 1):
                                end_bkg = act_idx_diff.shape[1] - 1
                            else:
                                end_bkg = range_idx[2 * i + 2]

                            label_lst.append(0)
                            feature_lst.append(
                                utils.feature_sampling(features[b], start_bkg, end_bkg + 1, self.sampling_size))

                        start_act = range_idx[2 * i] + 1
                        end_act = range_idx[2 * i + 1]

                        complete_score_act = utils.get_oic_score(cas_sigmoid_fuse[b, :, c], start=start_act,
                                                                 end=end_act)

                        loss_score_act_batch += 1 - complete_score_act

                        act_count += 1

                    # The contrastive term needs at least two action segments.
                    if sum(label_lst) > 1:
                        feature_lst = torch.stack(feature_lst, 0).clone()
                        feature_lst = feature_lst / torch.norm(feature_lst, dim=1, p=2).unsqueeze(1)
                        label_lst = torch.tensor(label_lst, dtype=torch.float, device=feature_lst.device)

                        sim_matrix = torch.matmul(feature_lst, torch.transpose(feature_lst, 0, 1)) / self.tau

                        sim_matrix = torch.exp(sim_matrix)

                        # A segment is never its own positive.
                        sim_matrix = sim_matrix.clone().fill_diagonal_(0)

                        scores = (sim_matrix * label_lst.unsqueeze(1)).sum(dim=0) / sim_matrix.sum(dim=0)

                        # NOTE(review): assigned, not accumulated - with several
                        # ground-truth classes only the last class that reaches
                        # this point contributes. Confirm this is intended.
                        loss_feat_batch = (-label_lst * torch.log(scores)).sum() / label_lst.sum()

                if act_count > 0:
                    loss_score_act += loss_score_act_batch / act_count
                    loss_feat += loss_feat_batch

            # +1 where a background segment starts, -1 where it ends.
            bkg_idx_diff = (1 - new_dense_anno[:, 1:]) - (1 - new_dense_anno[:, :-1])
            for b in range(new_dense_anno.shape[0]):
                gt_classes = torch.nonzero(label[b]).squeeze(1)
                loss_score_bkg_batch = 0
                bkg_count = 0

                for c in gt_classes:
                    range_idx = self._segment_bounds(bkg_idx_diff[b, :, c])
                    if len(range_idx) == 0:
                        continue

                    for i in range(len(range_idx) // 2):
                        if range_idx[2 * i + 1] - range_idx[2 * i] < 1:
                            continue

                        start_bkg = range_idx[2 * i] + 1
                        end_bkg = range_idx[2 * i + 1]

                        complete_score_bkg = utils.get_oic_score(1 - cas_sigmoid_fuse[b, :, c], start=start_bkg,
                                                                 end=end_bkg)

                        loss_score_bkg_batch += 1 - complete_score_bkg

                        bkg_count += 1

                if bkg_count > 0:
                    loss_score_bkg += loss_score_bkg_batch / bkg_count

            loss_score_act = loss_score_act / new_dense_anno.shape[0]
            loss_score_bkg = loss_score_bkg / new_dense_anno.shape[0]

            loss_feat = loss_feat / new_dense_anno.shape[0]

        loss_score = (loss_score_act + loss_score_bkg) ** 2

        loss_total = (self.lambdas[0] * loss_vid
                      + self.lambdas[1] * loss_frame
                      + self.lambdas[2] * loss_frame_bkg
                      + self.lambdas[3] * loss_score
                      + self.lambdas[4] * loss_feat
                      + self.lambdas[5] * pseudo_instance_loss
                      + self.lambdas[6] * loss_embedding)

        # pseudo_instance_loss and loss_embedding are part of loss_total but
        # are not reported individually; the key set is kept unchanged.
        loss["loss_vid"] = loss_vid
        loss["loss_frame"] = loss_frame
        loss["loss_frame_bkg"] = loss_frame_bkg
        loss["loss_score_act"] = loss_score_act
        loss["loss_score_bkg"] = loss_score_bkg
        loss["loss_score"] = loss_score
        loss["loss_feat"] = loss_feat
        loss["loss_total"] = loss_total

        return loss_total, loss